library(cregg)
library(dplyr)
library(ggplot2)
library(stringr)
library(formula.tools)
library(stringr)
library(patchwork)
library(grDevices)


# Make factor level names unique across the attributes used in a conjoint
# formula.
#
# cregg does not accept two attributes that share a level name. This function
# appends "@" to a level name until it no longer collides with any level name
# already seen in an earlier attribute of the formula (or an earlier level of
# the same attribute). The markers can be stripped from the cregg output
# afterwards, e.g. by removing every "@" from the `level` column.
#
# cregg_function_input: data frame that will be passed to cregg; every
#   variable on the right-hand side of `formula_to_be_used` must be a column.
# formula_to_be_used: formula whose right-hand side lists the attributes.
# Returns `cregg_function_input` with colliding factor levels renamed. Level
#   order and the mapping of observations to levels are unchanged; columns
#   without levels (non-factors) are left untouched.
# Raises an error if a right-hand-side variable is not a column of the data.
fix_duplicate_level_names <- function(cregg_function_input, formula_to_be_used) {
  # The right-hand side is the last element of a one- or two-sided formula.
  rhs <- formula_to_be_used[[length(formula_to_be_used)]]
  # all.vars() yields plain column names (backticks already removed).
  variables_to_be_used <- all.vars(rhs)

  missing_variables <- setdiff(variables_to_be_used, names(cregg_function_input))
  if (length(missing_variables) > 0) {
    stop("Formula variables not found in the data: ",
         paste(missing_variables, collapse = ", "), call. = FALSE)
  }

  observed_level_names <- character(0)

  for (variable in variables_to_be_used) {
    # [[ ]] returns the column itself for both data frames and tibbles.
    level_names <- levels(cregg_function_input[[variable]])
    if (is.null(level_names)) {
      next
    }

    for (i in seq_along(level_names)) {
      unique_name <- level_names[i]
      # Keep appending "@" so that a third (fourth, ...) attribute sharing the
      # same name does not collide with the "name@" given to the second one.
      while (unique_name %in% observed_level_names) {
        unique_name <- paste0(unique_name, "@")
      }
      level_names[i] <- unique_name
      observed_level_names <- c(observed_level_names, unique_name)
    }

    # Renaming levels in place (instead of recode_factor(), which moves the
    # recoded levels to the front) keeps the level order and reference level.
    levels(cregg_function_input[[variable]]) <- level_names
  }

  cregg_function_input
}

# Remove the "@" markers that make level names unique before cregg estimation.
#
# cregg_output: data frame with a `level` column (factor or character).
# Returns `cregg_output` with every "@" removed from `level`, as a factor.
#   When `level` is already a factor, the cleaned levels keep their original
#   order (first occurrence wins), so plots keep cregg's level order instead of
#   an alphabetical one. Character input is converted with as.factor().
fix_level_name_in_cregg_output <- function(cregg_output) {
  original_level <- cregg_output$level
  cleaned_level <- gsub("@", "", as.character(original_level), fixed = TRUE)

  if (is.factor(original_level)) {
    # Several marked levels (e.g. "x" and "x@") collapse back into one name.
    level_order <- unique(gsub("@", "", levels(original_level), fixed = TRUE))
    cregg_output$level <- factor(cleaned_level, levels = level_order)
  } else {
    cregg_output$level <- as.factor(cleaned_level)
  }

  cregg_output
}


# Plot vote-choice marginal means, one facet per conjoint feature.
#
# cregg_output: marginal-means data frame from cregg::cj(), with columns
#   feature, level, estimate, lower and upper.
# show_only_relevant_factors: if TRUE, keep only the features in
#   `relevant_factors` and zoom to a narrow x range; otherwise show every
#   feature with a wider range and smaller level labels.
# relevant_factors: feature names kept when show_only_relevant_factors is TRUE.
# Returns a ggplot object.
make_vote_choice_plot <- function(cregg_output,
                                  show_only_relevant_factors = TRUE,
                                  relevant_factors = c("Father's Job",
                                                       "University",
                                                       "Occupation",
                                                       "High School")) {
  if (show_only_relevant_factors) {
    cregg_output <- cregg_output %>%
      filter(feature %in% relevant_factors)
    plot_limits <- c(0.4, 0.6)
    y_label_size <- element_text(size = 8)
  } else {
    # wider range so that the policy items fall inside the visible window
    plot_limits <- c(0.3, 0.7)

    # smaller text to avoid overlaps on policy items
    y_label_size <- element_text(size = 6)
  }

  cregg_output %>%
    ggplot(aes(x = estimate, y = level)) +
    geom_point() +
    geom_errorbar(aes(xmin = lower, xmax = upper), width = 0) +
    facet_wrap(~feature, dir = "v", strip.position = "top",
               scales = "free_y", ncol = 1) +
    scale_x_continuous(name = "Vote choice Marginal Means",
                       breaks = seq(0, 1, by = 0.05)) +
    # Zoom with coord_cartesian() rather than scale limits: scale limits drop
    # every point or error bar that extends beyond them, coord_cartesian()
    # only clips what is drawn.
    coord_cartesian(xlim = plot_limits) +
    geom_vline(xintercept = 0.5, linetype = "dashed") +
    theme_bw() +
    theme(axis.text.y = y_label_size,
          legend.position = "none")
}

# Attach a human-readable outcome label to the two cregg outputs.
#
# The script calls this function for its effect. It looks up
# `good_region_output` and `good_country_output` from `envir` (enclosing
# environments are searched too), adds an `outcome_var` column to each, and
# assigns the labelled data frames back into `envir`.
#
# envir: environment holding the outputs; defaults to the caller's environment.
# Returns, invisibly, a named list with the two labelled data frames.
fix_outcome_var_names_for_plot <- function(envir = parent.frame()) {
  label_outcome <- function(cregg_output, label) {
    # rep() rather than a scalar so zero-row outputs are handled as well
    cregg_output[["outcome_var"]] <- rep(label, nrow(cregg_output))
    cregg_output
  }

  region <- label_outcome(get("good_region_output", envir = envir),
                          "Good for the Region")
  country <- label_outcome(get("good_country_output", envir = envir),
                           "Good for the Country")

  # Modifying the data frames inside this function would only change local
  # copies, so write the results back into the caller's environment.
  assign("good_region_output", region, envir = envir)
  assign("good_country_output", country, envir = envir)

  invisible(list(good_region_output = region, good_country_output = country))
}

# Side-by-side marginal-means plot for the "good for the region" and
# "good for the country" outcomes.
#
# cregg_output_region, cregg_output_national: marginal-means data frames from
#   cregg::cj(), with columns feature, level, estimate, lower and upper.
# show_only_relevant_factors: if TRUE, keep only the features in
#   `relevant_factors` and zoom to a narrow x range; otherwise show every
#   feature with a wider range and smaller level labels.
# relevant_factors: feature names kept when show_only_relevant_factors is TRUE.
# Returns a patchwork object: the region panel on the left and the country
#   panel on the right. The right panel hides its level labels, which assumes
#   both inputs share the same features and levels — TODO confirm.
make_local_national_comparison_plot <- function(cregg_output_region,
                                                cregg_output_national,
                                                show_only_relevant_factors = TRUE,
                                                relevant_factors = c("Father's Job",
                                                                     "University",
                                                                     "Occupation",
                                                                     "High School")) {
  if (show_only_relevant_factors) {
    cregg_output_region <- cregg_output_region %>%
      filter(feature %in% relevant_factors)
    cregg_output_national <- cregg_output_national %>%
      filter(feature %in% relevant_factors)
    plot_limits <- c(0.4, 0.6)
    y_label_size <- element_text(size = 8)
  } else {
    # wider range so that the policy items fall inside the visible window
    plot_limits <- c(0.3, 0.7)

    # smaller text to avoid overlaps on policy items
    y_label_size <- element_text(size = 6)
  }

  # Layers shared by both panels; only the title and y-axis theming differ.
  build_panel <- function(cregg_output, title) {
    cregg_output %>%
      ggplot(aes(x = estimate, y = level)) +
      geom_point() +
      geom_errorbar(aes(xmin = lower, xmax = upper), width = 0) +
      facet_wrap(~feature, dir = "v", strip.position = "top",
                 scales = "free_y", ncol = 1) +
      # These panels show region/country evaluations, not vote choice.
      scale_x_continuous(name = "Marginal Means",
                         breaks = seq(0, 1, by = 0.05)) +
      # Zoom with coord_cartesian(): scale limits would silently drop every
      # estimate or error bar extending beyond plot_limits.
      coord_cartesian(xlim = plot_limits) +
      geom_vline(xintercept = 0.5, linetype = "dashed") +
      ggtitle(title) +
      theme_bw()
  }

  region_panel <- build_panel(cregg_output_region, "Good for Region MM") +
    theme(axis.text.y = y_label_size,
          legend.position = "none")

  national_panel <- build_panel(cregg_output_national, "Good for Country MM") +
    theme(axis.title.y = element_blank(),
          axis.text.y = element_blank(),
          axis.ticks.y = element_blank(),
          legend.position = "none")

  region_panel + national_panel
}

# Input and output locations ----
path_to_data <- "data/first_survey_formatted.rds"

path_to_figures <- "figures"
paper_figures_folder <- "paper_plots"
appendix_figures_folder <- "appendix_plots"

# File names of the figures written at the end of the script
vote_choice_outcome_plot <- "no_satisficer_vote_choice_mm.eps"
comparison_plot <- "no_satisficer_region_nation_comparison_mm.eps"
duration_plot <- "survey_duration.eps"

# Load data ----
survey_data <- readRDS(path_to_data)
survey_data$duration <- as.numeric(survey_data$duration)

# Response-duration histogram ----
# One row per respondent; responses of 3000 seconds or more are left out as
# overly long.
survey_duration_plot <- survey_data %>%
  distinct(ResponseId, .keep_all = TRUE) %>%
  filter(duration < 3000) %>%
  ggplot(aes(x = as.numeric(duration))) +
  geom_histogram(bins = 30, alpha = 0.8) +
  scale_x_continuous(
    name = "Survey Response Duration in Seconds",
    breaks = seq(0, 4000, 500)
  ) +
  scale_y_continuous(name = "Number of Respondents") +
  theme_bw()

# Satisficer detection ----
# One data frame per respondent, keeping the first row of each
# (respondent, task) pair so every task contributes a single answer.
individual_respondents <- survey_data %>%
  distinct(ResponseId, task_number, .keep_all = TRUE) %>%
  group_split(ResponseId)

# Median response duration over respondents (one value each), not over rows.
median_duration <- survey_data %>%
  distinct(ResponseId, .keep_all = TRUE) %>%
  pull(duration) %>%
  median(na.rm = TRUE)
print(median_duration)

# TRUE when every answer is 1/TRUE or every answer is 0/FALSE. Assumes the
# outcome columns are logical or 0/1 — TODO confirm. An NA answer gives NA
# unless both answers already occur.
gave_same_answer_every_task <- function(answers) {
  all(answers) | !any(answers)
}

# identify respondents that picked 1 or 0 for all 8 conjoint tasks
satisficer_behavior_vote_choice <- vapply(
  individual_respondents,
  function(x) gave_same_answer_every_task(x$vote_choice),
  logical(1)
)
satisficer_behavior_region <- vapply(
  individual_respondents,
  function(x) gave_same_answer_every_task(x$good_region),
  logical(1)
)
satisficer_behavior_country <- vapply(
  individual_respondents,
  function(x) gave_same_answer_every_task(x$good_country),
  logical(1)
)

# Flag respondents who straight-lined any of the three outcomes.
selected_all_same_answer_mask <- satisficer_behavior_vote_choice |
  satisficer_behavior_region |
  satisficer_behavior_country

# Per-respondent lookup table. Built directly as a data frame so the ID keeps
# its type and the flag stays logical (cbind() would coerce both to one
# matrix type).
ids <- unlist(
  lapply(individual_respondents, function(x) x$ResponseId[[1]]),
  use.names = FALSE
)
ids_all_same_answer <- data.frame(
  ids = ids,
  selected_all_same_answer_mask = selected_all_same_answer_mask,
  stringsAsFactors = FALSE
)

survey_data <- survey_data %>%
  left_join(ids_all_same_answer, by = c("ResponseId" = "ids"))

# Remove respondents who answered too quickly (300 seconds or less) or who
# straight-lined an outcome. filter() also drops rows whose condition is NA.
survey_data_no_satisficer <- survey_data %>%
  filter(duration > 300) %>%
  filter(!selected_all_same_answer_mask)

# Number of respondents removed by the two filters above.
number_satisficers <- n_distinct(survey_data$ResponseId) -
  n_distinct(survey_data_no_satisficer$ResponseId)

# Model formulas ----
# Each outcome is modelled on the same eight conjoint attributes.
good_region_formula <- good_region ~
  `Father's Job` + `Incumbent Status` + Gender + `Years in Employment` +
  University + Occupation + `High School` + `Policy Position`

good_country_formula <- good_country ~
  `Father's Job` + `Incumbent Status` + Gender + `Years in Employment` +
  University + Occupation + `High School` + `Policy Position`

vote_choice_formula <- vote_choice ~
  `Father's Job` + `Incumbent Status` + Gender + `Years in Employment` +
  University + Occupation + `High School` + `Policy Position`


# Estimation ----
# cregg rejects attributes that share level names, so make them unique first.
survey_data_no_satisficer <- fix_duplicate_level_names(
  survey_data_no_satisficer,
  good_region_formula
)

# Marginal means per outcome; the temporary "@" markers are removed from the
# level names right after each estimation.
vote_choice_output <- fix_level_name_in_cregg_output(
  cregg::cj(
    data = survey_data_no_satisficer,
    formula = vote_choice_formula,
    id = ~ResponseId,
    estimate = "mm"
  )
)

good_region_output <- fix_level_name_in_cregg_output(
  cregg::cj(
    data = survey_data_no_satisficer,
    formula = good_region_formula,
    id = ~ResponseId,
    estimate = "mm"
  )
)

good_country_output <- fix_level_name_in_cregg_output(
  cregg::cj(
    data = survey_data_no_satisficer,
    formula = good_country_formula,
    id = ~ResponseId,
    estimate = "mm"
  )
)

fix_outcome_var_names_for_plot()

# Figures ----
vote_choice_paper_plot <- make_vote_choice_plot(vote_choice_output)
region_national_comparion_paper_plot <- make_local_national_comparison_plot(
  cregg_output_region = good_region_output,
  cregg_output_national = good_country_output
)

# Save figures ----
# All three figures are written to the appendix folder.
# NOTE(review): paper_figures_folder is defined but unused — confirm that the
# appendix folder is the intended destination for the "paper" plots.
appendix_figures_path <- file.path(path_to_figures, appendix_figures_folder)

# ggsave() does not create missing directories by default, so create the
# target folder up front (no-op if it already exists).
dir.create(appendix_figures_path, recursive = TRUE, showWarnings = FALSE)

# Write one plot as an EPS file (cairo_ps device) into the appendix folder.
save_appendix_eps <- function(file_name, plot, height, width) {
  ggsave(
    filename = file.path(appendix_figures_path, file_name),
    plot = plot,
    height = height,
    width = width,
    units = "in",
    device = cairo_ps,
    fallback_resolution = 200
  )
}

save_appendix_eps(vote_choice_outcome_plot, vote_choice_paper_plot,
                  height = 8, width = 6)
save_appendix_eps(comparison_plot, region_national_comparion_paper_plot,
                  height = 9.1, width = 10)
save_appendix_eps(duration_plot, survey_duration_plot,
                  height = 6, width = 6)




